Safetensor metadata mismatch fix in Mcore export #1422

Open
jinhangchoi wants to merge 7 commits into NVIDIA:main from jinhangchoi:jinhangc/safetensor-metadata-fix

Conversation

@jinhangchoi

jinhangchoi commented May 9, 2026


What does this PR do?

Type of change: Bug fix

In MCore export, shard metadata (*.json) and shard weights (*.safetensors) are produced from mutable shard maps and can be generated from different views of the same dict. In real Nemotron 3 Ultra PTQ runs, I observed MTP-related drift where metadata and shard contents were not aligned. This is plausible because MTP is stage-local (typically last-stage only), so per-rank shard contents are intentionally asymmetric.

The exact mutation interleaving is hard to prove from this code path alone, but the current implementation reads mutable shard maps across separate write steps, making metadata/weights consistency timing-sensitive. The issue is most visible with PP>1, where staggered per-shard writes widen the timing window between metadata and tensor-file generation.

This PR makes shard serialization deterministic in both paths:

  • take a per-shard snapshot once,
  • write .safetensors from that snapshot,
  • write per-shard .json from the same snapshot.

Apply this consistently to:

  • save_safetensors_by_layer_index
  • save_safetensors

This guarantees that shard JSON and shard safetensors cannot diverge due to late dict mutations.
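
The three steps listed above (snapshot once, write the shard from the snapshot, write the JSON from the same snapshot) can be sketched roughly as follows. This is a minimal illustration, not the code in mcore_custom.py: `write_shard_and_index` and `fake_write_shard` are hypothetical names, tensors are stood in for by `bytes` objects so the sketch runs without torch or safetensors, and the `total_size` layout of the JSON is an assumption.

```python
import json
import os
import tempfile
from typing import Callable, Dict


def fake_write_shard(snapshot: Dict[str, bytes], path: str) -> None:
    """Hypothetical stand-in for a real shard writer: concatenates the payloads."""
    with open(path, "wb") as f:
        for key in sorted(snapshot):
            f.write(snapshot[key])


def write_shard_and_index(
    tensors: Dict[str, bytes],
    save_directory: str,
    filename: str,
    write_shard: Callable[[Dict[str, bytes], str], None],
) -> Dict[str, str]:
    """Write one shard and its per-shard JSON from a single frozen snapshot."""
    ckpt_filename = filename + ".safetensors"
    meta_filename = filename + ".json"

    # Shallow copy: freezes the key -> value mapping. Values are shared, not cloned.
    snapshot = dict(tensors)

    # Both outputs below read only `snapshot`, never the live `tensors` dict.
    write_shard(snapshot, os.path.join(save_directory, ckpt_filename))

    weight_map = {key: ckpt_filename for key in snapshot}
    total_size = sum(len(val) for val in snapshot.values())
    with open(os.path.join(save_directory, meta_filename), "w") as f:
        json.dump(
            {"metadata": {"total_size": total_size}, "weight_map": weight_map},
            f,
            indent=4,
        )
    return weight_map


if __name__ == "__main__":
    out_dir = tempfile.mkdtemp()
    source = {"a": b"\x00" * 4, "b": b"\x00" * 8}
    print(write_shard_and_index(source, out_dir, "shard", fake_write_shard))
```

Note that `dict(...)` only freezes which keys map to which objects. If a tensor were modified in place after the snapshot, both outputs would still see the same object, so the key sets stay aligned, but the sketch makes no claim about value-level immutability.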

Usage

N/A

Testing

  • Reproduced in Nemotron 3 Ultra MCore PTQ export under PP>1 (multi-stage pipeline parallelism).

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ❌
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Bug Fixes
    • Model export now writes each binary shard immediately and uses a frozen per-shard/per-layer snapshot when generating its metadata. This prevents mismatches between shard contents and metadata, improving export consistency and reliability.

@jinhangchoi jinhangchoi requested a review from a team as a code owner May 9, 2026 00:03
@jinhangchoi jinhangchoi requested a review from meenchen May 9, 2026 00:03
@copy-pr-bot

copy-pr-bot Bot commented May 9, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 9, 2026

Contributor


Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Shard .safetensors writes are moved to occur immediately after each shard's filename is computed (in both save_safetensors and save_safetensors_by_layer_index); per-shard weight_map/total_size computation and .json metadata writing now happen after the shard file save, and the previous post-metadata save_file calls were removed.

Changes

Safetensors Write Order

| Layer / File(s) | Summary |
| --- | --- |
| Per-part safetensors-before-metadata (modelopt/torch/export/plugins/mcore_custom.py) | In save_safetensors, each per-part .safetensors shard is written immediately after ckpt_filename is computed; the later save_file after metadata JSON was removed. weight_map/total_size and per-part .json are written afterward. |
| Per-layer safetensors-before-metadata (modelopt/torch/export/plugins/mcore_custom.py) | In save_safetensors_by_layer_index, each layer .safetensors shard is written immediately after ckpt_filename is computed; the trailing post-metadata save_file was removed. weight_map/layer_total_size and per-layer .json are written after the shard save. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Safetensor metadata mismatch fix in Mcore export' directly and accurately describes the main change: fixing a timing-sensitive bug where shard metadata and weights diverge due to mutable shard map views. The title is concise, specific, and clearly communicates the primary intent. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns detected. Changes add snapshot freezing without unsafe deserialization, remote code loading, eval/exec, or new dependencies. |


@coderabbitai coderabbitai Bot left a comment

Contributor

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/export/plugins/mcore_custom.py (1)

308-320: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Freeze layer_state_dict once to fully eliminate metadata/file drift.

This reorder helps, but you still read a live mutable dict twice. If layer_state_dict changes after save_file(...) and before the metadata loop, .json can diverge from the written .safetensors.

Proposed hardening
     for layer_index, layer_state_dict in layer_state_dicts.items():
         filename = name_template.format(layer_index, total_layers)
         meta_filename = filename + ".json"
         ckpt_filename = filename + ".safetensors"

+        # Freeze key->tensor mapping used by both outputs.
+        frozen_layer_state_dict = dict(layer_state_dict)
+
         # Write safetensors first, then build the per-layer meta JSON from the same dict.
         # Order matters: any late mutations to layer_state_dict (e.g. MTP tensors added after
         # the dict was first constructed) must be captured by both files.  Writing safetensors
         # first ensures the JSON is always consistent with what is physically on disk.
-        save_file(layer_state_dict, save_directory + "/" + ckpt_filename, metadata={"format": "pt"})
+        save_file(
+            frozen_layer_state_dict,
+            save_directory + "/" + ckpt_filename,
+            metadata={"format": "pt"},
+        )

         weight_map = {}
         layer_total_size = 0
-        for key, val in layer_state_dict.items():
+        for key, val in frozen_layer_state_dict.items():
             tensor_size = val.numel() * val.element_size()
             layer_total_size += tensor_size
             weight_map[key] = ckpt_filename
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/export/plugins/mcore_custom.py` around lines 308 - 320,
layer_state_dict is mutated after being written which can cause metadata/file
drift; snapshot it and use that immutable copy for both the safetensors write
and the metadata loop. Specifically, create a frozen copy of layer_state_dict
(e.g., snapshot = dict(layer_state_dict)) and pass snapshot to save_file(...)
and iterate snapshot.items() when building weight_map/layer_total_size so
save_file, weight_map, and layer_total_size are computed from the exact same
data; reference save_file, layer_state_dict, weight_map, ckpt_filename in your
change.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@modelopt/torch/export/plugins/mcore_custom.py`:
- Around line 308-320: layer_state_dict is mutated after being written which can
cause metadata/file drift; snapshot it and use that immutable copy for both the
safetensors write and the metadata loop. Specifically, create a frozen copy of
layer_state_dict (e.g., snapshot = dict(layer_state_dict)) and pass snapshot to
save_file(...) and iterate snapshot.items() when building
weight_map/layer_total_size so save_file, weight_map, and layer_total_size are
computed from the exact same data; reference save_file, layer_state_dict,
weight_map, ckpt_filename in your change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 8c7dee09-3fac-42b5-92ed-6c95d8e50462

📥 Commits

Reviewing files that changed from the base of the PR and between 1d796f9 and a08d952.

📒 Files selected for processing (1)
  • modelopt/torch/export/plugins/mcore_custom.py

@jenchen13 jenchen13 self-requested a review May 11, 2026 14:16
Comment thread modelopt/torch/export/plugins/mcore_custom.py Outdated
@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from a08d952 to 693a118 Compare May 18, 2026 00:06

@shengliangxu shengliangxu left a comment

Collaborator

LGTM

@codecov

codecov Bot commented May 18, 2026


Codecov Report

❌ Patch coverage is 50.00000% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.01%. Comparing base (54ce4e0) to head (799bce3).
⚠️ Report is 1 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| modelopt/torch/export/plugins/mcore_custom.py | 50.00% | 3 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1422      +/-   ##
==========================================
- Coverage   76.27%   76.01%   -0.26%     
==========================================
  Files         489      489              
  Lines       54415    54417       +2     
==========================================
- Hits        41504    41365     -139     
- Misses      12911    13052     +141     
| Flag | Coverage Δ |
| --- | --- |
| examples | 31.72% <0.00%> (-10.28%) ⬇️ |
| gpu | 58.41% <50.00%> (-1.46%) ⬇️ |
| regression | 14.88% <0.00%> (-0.25%) ⬇️ |
| unit | 54.05% <50.00%> (+0.04%) ⬆️ |


@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from 19b366a to 8951edb Compare June 4, 2026 23:31

@cjluo-nv cjluo-nv left a comment

Collaborator

Bot review — DM the bot to share feedback.

Small reorder (write safetensors shard before its per-layer JSON) that mitigates a real symptom seen in MCore PTQ on Nemotron, but I'm not fully convinced it addresses the root cause and there are gaps:

  • Root cause story is unclear from the code path I can see. In unified_export_megatron.py::save_pretrained, _get_mtp_state_dict() is invoked synchronously before save_safetensors_by_layer_index, and the result is merged into layer_state_dicts[num_layers] before the barrier. There's no obvious place in the current code where MTP state is added between the two writes inside the per-layer loop. If the real bug is that something is still mutating layer_state_dict between json.dump and save_file (or vice versa), then this reorder only flips which side wins — after this PR, a late mutation will cause the JSON to advertise keys that were never written to the .safetensors shard, which is arguably worse than the original direction. A snapshot-then-write (e.g. weight_map/sizes computed once into local vars, then write safetensors and JSON from those, with no further dict reads) would actually be race-safe. Worth the author confirming what concretely mutates layer_state_dict after save_file was originally called.
  • Sibling function not updated. save_safetensors (just above, used for the non-per-layer path) has the exact same JSON-then-safetensors ordering. If the ordering matters here, it presumably matters there too — please update both for consistency, or document why only the per-layer path is affected.
  • No test. The author checked "Did you write any new necessary tests? ❌". A regression test that exercises the MTP-included shard and asserts that every key in the per-layer index JSON is actually present in the corresponding .safetensors would lock this down.
  • No CHANGELOG entry (also unchecked).
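
The consistency check proposed in the test bullet above (every key in the per-layer index JSON is present in the corresponding .safetensors) could look roughly like the sketch below. It reads only the shard header, relying on the documented safetensors layout: an 8-byte little-endian header length, followed by a UTF-8 JSON header whose keys are tensor names plus an optional `__metadata__` entry. The names `read_safetensors_keys`, `index_mismatch`, and `write_demo_shard` are hypothetical, and the demo writer only mimics the layout; it has not been validated against the safetensors library.

```python
import json
import os
import struct
import tempfile
from typing import Dict, Iterable, Set, Tuple


def read_safetensors_keys(path: str) -> Set[str]:
    """Return the tensor names recorded in a .safetensors file header."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return set(header)


def index_mismatch(
    weight_map: Dict[str, str], shard_path: str, shard_name: str
) -> Tuple[Set[str], Set[str]]:
    """Return (keys only in the index, keys only in the shard) for one shard."""
    indexed = {key for key, name in weight_map.items() if name == shard_name}
    stored = read_safetensors_keys(shard_path)
    return indexed - stored, stored - indexed


def write_demo_shard(path: str, names: Iterable[str]) -> None:
    """Demo-only writer: one zeroed 4-byte F32 entry per name."""
    header = {}
    offset = 0
    for name in names:
        header[name] = {"dtype": "F32", "shape": [1], "data_offsets": [offset, offset + 4]}
        offset += 4
    header["__metadata__"] = {"format": "pt"}
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        f.write(b"\x00" * offset)


if __name__ == "__main__":
    shard = os.path.join(tempfile.mkdtemp(), "layer.safetensors")
    write_demo_shard(shard, ["w", "mtp.w"])
    # The index omits "mtp.w", so it shows up as a shard-only key.
    print(index_mismatch({"w": "layer.safetensors"}, shard, "layer.safetensors"))
```

Returning both difference sets, rather than a single boolean, lets a test distinguish an index that lists too few keys from one that lists keys the shard does not contain.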

@jinhangchoi

Author
  • Updated the PR description and added the save_safetensors update.
  • Meanwhile, a regression test is a little bit difficult to design. Need some brainstorming.
  • As a bugfix, I don't think we need a CHANGELOG entry here.

@cjluo-nv cjluo-nv left a comment

Collaborator

Bot review — DM the bot to share feedback.

Re-review focused on whether cjluo-nv's first two pieces of feedback have been addressed.

  • Feedback #1 — root cause / snapshot-then-write: NOT addressed. 💬 Author updated PR description and reordered the writes, but the underlying concern was that a bare reorder only flips which side "wins" a race — if layer_state_dict is mutated between save_file and the metadata loop, the JSON will now advertise keys that were never written to the .safetensors shard (arguably worse than the original direction). The race-safe fix suggested by both cjluo-nv and CodeRabbit was to snapshot the dict once (e.g. frozen = dict(layer_state_dict)) and drive both save_file and the weight_map loop from that frozen view. The current diff at mcore_custom.py:308-329 still calls save_file(layer_state_dict, ...) and then iterates the live layer_state_dict.items() — same in save_safetensors at lines 254-268 with tensors. So if the author's hypothesis about late mutation is correct, the consistency hole is still open; if it's wrong, it'd be good to hear what concretely was observed to drift on Nemotron 3 Ultra. Worth a human reviewer pushing on this before merge.

  • Feedback #2 — sibling save_safetensors not updated: addressed. 💬 The new commit applies the same reorder to save_safetensors (lines 254-256), so both paths are consistent. (The snapshot caveat from #1 applies to both, though.)

For completeness on the other items the author replied to: regression test is still missing (author says it's hard to design), and CHANGELOG was declined as "bugfix doesn't need it" — leaving those as-is for the human owner to weigh in on.

@jinhangchoi

Author

I agree snapshot-then-write is the generic race-safe pattern if in-loop mutation is expected. In this PR, the observed MCore issue is better characterized as an ordering/coherence problem: weak ordering guarantees between shard data and shard metadata generation, not a proven same-process mutation between save_file and the subsequent metadata loop.

Under PP>1, MTP is stage-local (typically last-stage only), and staggered shard emission makes this ordering sensitivity more visible. _get_mtp_state_dict() is merged synchronously before save_safetensors_by_layer_index, and we do not currently have an identified concurrent mutator inside the writer loop. The intent of this change is therefore to standardize safetensors -> shard JSON ordering across both writer paths to remove the stale-metadata-first failure mode I observed.
If we want strict future-proof determinism regardless of potential mutators, explicit finalize-then-serialize (or snapshot-at-boundary) is the stronger follow-up.

@cjluo-nv cjluo-nv left a comment

Collaborator

Bot review — DM the bot to share feedback.

Re-review: sibling-function fix landed, but the core snapshot-then-write concern is still open and worth a human sign-off before merge.

  • 💬 Author replied (PR comment): "snapshot-then-write is the generic race-safe pattern if in-loop mutation is expected... the observed MCore issue is better characterized as an ordering/coherence problem, not a proven same-process mutation between save_file and the metadata loop... explicit finalize-then-serialize is the stronger follow-up." — still warrants human sign-off because the current diff at mcore_custom.py:311-329 (and the matching save_safetensors at 254-268) still passes the live layer_state_dict/tensors to save_file and then re-iterates the same live dict to build weight_map. If the author's "no in-loop mutator" assumption holds, the reorder alone fixes the symptom; if it doesn't, the new ordering means the JSON can advertise keys that aren't in the .safetensors shard — arguably worse than the original direction. A two-line frozen = dict(layer_state_dict) snapshot would close this regardless of which hypothesis is right, and the author already agrees it's the more defensible pattern.
  • 💬 Sibling save_safetensors reorder: addressed in this revision (lines 254-256). Both paths are now consistent.
  • 💬 Author replied: "regression test is a little bit difficult to design. Need some brainstorm." — flagging anyway because the hypothesized failure mode (per-layer JSON listing keys not present in the shard) is straightforward to assert in a unit test by mocking save_file/mutating the dict between writes; a real distributed Nemotron repro isn't required.
  • 💬 Author replied: "as bugfix, I don't think we need CHANGELOG here." — leaving for the human owner to weigh in on per project policy.
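
A single-process test of the kind described in the list above might be shaped like this sketch, which mutates the source dict from inside a mocked `save_file` and then compares the key sets. `export_layer` is a hypothetical stand-in for the per-layer writer, not the function in mcore_custom.py. A real test would more likely patch the module-level `save_file` with `unittest.mock.patch` instead of passing it in as an argument.

```python
from unittest import mock


def export_layer(layer_state_dict, ckpt_filename, save_file):
    """Hypothetical per-layer writer: snapshot once, use it for both outputs."""
    frozen = dict(layer_state_dict)
    save_file(frozen, ckpt_filename, metadata={"format": "pt"})
    return {key: ckpt_filename for key in frozen}


def run_freeze_invariant_check():
    """Return (index keys, written keys, live source keys) after a mid-write mutation."""
    source = {"decoder.weight": object()}
    written = {}

    def mutate_then_record(tensors, filename, metadata=None):
        # Simulate a late mutation of the live dict while the shard is being saved.
        source["mtp.late.weight"] = object()
        written.update(tensors)

    fake_save_file = mock.Mock(side_effect=mutate_then_record)
    weight_map = export_layer(source, "layer.safetensors", fake_save_file)
    fake_save_file.assert_called_once()
    return set(weight_map), set(written), set(source)


if __name__ == "__main__":
    index_keys, written_keys, source_keys = run_freeze_invariant_check()
    # The index and the written shard agree; the late key exists only in the live dict.
    assert index_keys == written_keys
    assert "mtp.late.weight" in source_keys - index_keys
```

Such a test exercises only the snapshot invariant inside one process. It does not reproduce a multi-rank export.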

@jinhangchoi

Author

I would like to emphasize this is not a hypothetical issue. It was observed in production export artifacts.

for key, val in layer_state_dict.items():
    weight_map[key] = ckpt_filename
with open(meta_filename, "w") as f:
    json.dump({"metadata": ..., "weight_map": weight_map}, f, indent=4)
save_file(layer_state_dict, ckpt_filename, metadata=...)   # <-- runs LAST

weight_map is built from the dict's keys at iteration time, the JSON is committed, then save_file writes the shard from whatever the dict contains now. Any late mutation lands in the binary but not in the per-layer JSON — and consequently not in rank 0's aggregated model.safetensors.index.json. I saw this in production exports under PP>1: the shards physically contained the keys; the index didn't list them.

The direction in your comment is inverted. By construction of the pre-fix code, the failure is "the JSON lists fewer keys than the shard contains": the shard is the superset and the index is a subset. The comment describes the opposite (the JSON listing keys that aren't in the shard), which the pre-fix code can't produce.
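
The two directions discussed in this thread can be made concrete with a small simulation that tracks only key sets. It is an illustrative sketch, assuming a late mutation that adds one key between the two writes (whether such a mutation occurs inside the writer loop is the open question in this conversation). All function names are hypothetical.

```python
def add_late_key(state):
    """Hypothetical late mutation: one extra key appears between the two writes."""
    state["mtp.late.weight"] = 0


def json_then_shard(state, mutate):
    """Pre-fix order: index committed first, shard written last from the live dict."""
    index_keys = set(state)  # weight_map built and JSON written here
    mutate(state)
    shard_keys = set(state)  # save_file sees whatever the dict holds now
    return index_keys, shard_keys


def shard_then_json(state, mutate):
    """Bare reorder: shard written first, index built afterwards from the live dict."""
    shard_keys = set(state)
    mutate(state)
    index_keys = set(state)
    return index_keys, shard_keys


def snapshot_then_both(state, mutate):
    """Snapshot once, then derive both outputs from the frozen copy."""
    frozen = dict(state)
    shard_keys = set(frozen)
    mutate(state)
    index_keys = set(frozen)
    return index_keys, shard_keys


if __name__ == "__main__":
    for writer in (json_then_shard, shard_then_json, snapshot_then_both):
        index_keys, shard_keys = writer({"decoder.weight": 0}, add_late_key)
        print(writer.__name__, sorted(index_keys), sorted(shard_keys))
```

Under this assumption, the pre-fix order leaves the index a strict subset of the shard, the bare reorder leaves the shard a strict subset of the index, and only the snapshot variant keeps the two sets equal.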

Regarding the regression test, I am supportive of adding a unit test along the lines you described, but I want to be clear about what it covers. It would be a freeze-invariant test, not a true regression test for the production failure. The production failure requires PP>1 plus the MTP load path that mutates layer_state_dicts[num_layers] after construction, and neither is exercised by a mocked single-process test. The unit test guards the invariant the freeze is there to guarantee.

The actual end-to-end repro is the original MCore export that motivated this fix. I would like to call this out so we don't oversell what the test asserts.

@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from 88b532f to 937eb21 Compare June 5, 2026 17:15

@cjluo-nv cjluo-nv left a comment

Collaborator

Bot review — DM the bot to share feedback.

Re-review.

Addressed since the previous round:

  • 💬 Snapshot-then-write applied in both save_safetensors (lines 254-256) and save_safetensors_by_layer_index (lines 311-317). frozen_tensors = dict(tensors) / frozen_layer_state_dict = dict(layer_state_dict) is now the single source for both save_file and the weight_map loop, so JSON and shard cannot diverge regardless of late mutations. This closes the main concern from cjluo-nv and CodeRabbit.
  • 💬 Sibling save_safetensors reorder + freeze: applied.
  • 💬 Author replied that a regression test is hard to design — a unit test was nonetheless added in tests/unit/torch/export/test_mcore_save_safetensors.py that mocks save_file and mutates the source dict mid-write to assert the JSON does not pick up the late key. This is exactly the freeze-invariant test that was suggested. 👍

Still warrants human sign-off:

  • The new test passes tmp_path (a pathlib.Path) as save_directory, but the production code does save_directory + "/" + ckpt_filename, which is Path + str and raises TypeError. The test as written looks like it would fail at the first save_file call before reaching the assertions. Either I'm missing something or the test wasn't run — please double-check by either casting to str(tmp_path) in the test or switching the source to os.path.join / Path /. (Same shape issue would also bite save_safetensors if it were ever exercised this way.)
  • 💬 Author replied: "as bugfix, I don't think we need CHANGELOG here." — deferred to human per project policy.
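
The `Path + str` concern raised in the list above is easy to reproduce in isolation. The sketch below uses two hypothetical helpers: `join_concat` mirrors the string concatenation quoted in the review, and `join_portable` shows the `os.path.join` alternative the reviewer mentions. Neither is taken from the repository.

```python
import os
from pathlib import Path


def join_concat(save_directory, ckpt_filename):
    """Mirrors the concatenation quoted in the review; only works for str inputs."""
    return save_directory + "/" + ckpt_filename


def join_portable(save_directory, ckpt_filename):
    """os.path.join accepts str and os.PathLike directories and returns a str."""
    return os.path.join(save_directory, ckpt_filename)


if __name__ == "__main__":
    directory = Path("export_dir")
    try:
        join_concat(directory, "model.safetensors")
    except TypeError as exc:
        print("concatenation failed:", exc)
    print(join_portable(directory, "model.safetensors"))
    print(join_concat(str(directory), "model.safetensors"))
```

Casting with `str(tmp_path)` in the test keeps the production code unchanged, while switching the writer to `os.path.join` makes it accept either type.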

Comment thread tests/unit/torch/export/test_mcore_save_safetensors.py
@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from 937eb21 to 352c422 Compare June 5, 2026 17:55
Comment thread tests/unit/torch/export/test_mcore_save_safetensors.py
@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from c35051c to ef71f1d Compare June 5, 2026 19:49

@meenchen meenchen left a comment

Contributor

LGTM

@meenchen

meenchen commented Jun 5, 2026

Contributor

/ok to test ef71f1d

Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
…hards for PP==2 in test_unified_export_megatron.py

Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
@jinhangchoi jinhangchoi force-pushed the jinhangc/safetensor-metadata-fix branch from ef71f1d to 799bce3 Compare June 5, 2026 22:22
@meenchen

meenchen commented Jun 5, 2026

Contributor

/ok to run 799bce3

@meenchen

meenchen commented Jun 5, 2026

Contributor

/ok to test 799bce3


Labels

None yet

Projects

None yet


5 participants